{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2015 The TensorFlow Authors. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# ==============================================================================\n",
    "\n",
    "\"\"\"Trains the MNIST network using a feed dictionary; training is launched through Kubeflow Fairing.\"\"\"\n",
    "import argparse\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import random\n",
    "import logging\n",
    "\n",
    "from six.moves import xrange  # pylint: disable=redefined-builtin\n",
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "from tensorflow.examples.tutorials.mnist import mnist\n",
    "\n",
    "from kubeflow import fairing\n",
    "# Configure Kubeflow Fairing to use the builder named 'append'.\n",
    "fairing.config.set_builder(name='append')\n",
    "\n",
    "# Directory that input_data.read_data_sets() reads the MNIST data from.\n",
    "INPUT_DATA_DIR = '/tmp/tensorflow/mnist/input_data/'\n",
    "\n",
    "# Training schedule and optimizer settings.\n",
    "MAX_STEPS = 2000\n",
    "BATCH_SIZE = 100\n",
    "LEARNING_RATE = 0.3\n",
    "\n",
    "# Sizes of the two hidden layers handed to mnist.inference()\n",
    "# (presumably units per layer -- confirm against the mnist module).\n",
    "HIDDEN_1 = 128\n",
    "HIDDEN_2 = 32\n",
    "\n",
    "# HACK: every job instance should ideally write to its own log subpath. That\n",
    "# is not possible here, so HOSTNAME is appended to the log directory instead.\n",
    "LOG_DIR = os.path.join(\n",
    "    os.getenv('TEST_TMPDIR', '/tmp'),\n",
    "    'tensorflow/mnist/logs/fully_connected_feed/',\n",
    "    os.getenv('HOSTNAME', ''),\n",
    ")\n",
    "# Despite its name this is a path ending in 'model.ckpt' inside LOG_DIR\n",
    "# (looks like a checkpoint file prefix rather than a directory -- confirm).\n",
    "MODEL_DIR = os.path.join(LOG_DIR, 'model.ckpt')\n",
    "\n",
    "def train():\n",
    "    \"\"\"Train the MNIST fully-connected network using a feed dictionary.\n",
    "\n",
    "    Builds the graph with ``mnist.inference``, ``mnist.loss`` and\n",
    "    ``mnist.training``, then runs ``MAX_STEPS`` training steps on batches of\n",
    "    ``BATCH_SIZE`` examples from the training split read from\n",
    "    ``INPUT_DATA_DIR``. Every 100 steps the loss is printed and the merged\n",
    "    summaries are written to ``LOG_DIR``. When training finishes, the\n",
    "    variables are saved with ``tf.train.Saver`` to ``MODEL_DIR``. The summary\n",
    "    writer and the session are always closed, even if a step raises.\n",
    "    \"\"\"\n",
    "    data_sets = input_data.read_data_sets(INPUT_DATA_DIR)\n",
    "    images_placeholder = tf.placeholder(\n",
    "        tf.float32, shape=(BATCH_SIZE, mnist.IMAGE_PIXELS))\n",
    "    # `(BATCH_SIZE,)` is a 1-tuple; `(BATCH_SIZE)` would just be an int.\n",
    "    labels_placeholder = tf.placeholder(tf.int32, shape=(BATCH_SIZE,))\n",
    "\n",
    "    logits = mnist.inference(images_placeholder,\n",
    "                             HIDDEN_1,\n",
    "                             HIDDEN_2)\n",
    "\n",
    "    loss = mnist.loss(logits, labels_placeholder)\n",
    "    train_op = mnist.training(loss, LEARNING_RATE)\n",
    "    summary = tf.summary.merge_all()\n",
    "    init = tf.global_variables_initializer()\n",
    "    saver = tf.train.Saver()\n",
    "\n",
    "    # The `with` block closes the session and releases its resources on exit.\n",
    "    with tf.Session() as sess:\n",
    "        summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)\n",
    "        try:\n",
    "            sess.run(init)\n",
    "\n",
    "            data_set = data_sets.train\n",
    "            for step in xrange(MAX_STEPS):\n",
    "                images_feed, labels_feed = data_set.next_batch(BATCH_SIZE, False)\n",
    "                feed_dict = {\n",
    "                    images_placeholder: images_feed,\n",
    "                    labels_placeholder: labels_feed,\n",
    "                }\n",
    "\n",
    "                _, loss_value = sess.run([train_op, loss],\n",
    "                                         feed_dict=feed_dict)\n",
    "                if step % 100 == 0:\n",
    "                    print(\"At step {}, loss = {}\".format(step, loss_value))\n",
    "                    # Summaries are evaluated in a separate run, i.e. after\n",
    "                    # this step's update has already been applied.\n",
    "                    summary_str = sess.run(summary, feed_dict=feed_dict)\n",
    "                    summary_writer.add_summary(summary_str, step)\n",
    "                    summary_writer.flush()\n",
    "\n",
    "            # Persist the trained variables so the run yields a reusable model.\n",
    "            save_path = saver.save(sess, MODEL_DIR)\n",
    "            print(\"Model saved to {}\".format(save_path))\n",
    "        finally:\n",
    "            # Flush pending events and close the event file.\n",
    "            summary_writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bind the Fairing-wrapped function to a new name so that re-running this\n",
    "# cell wraps the original `train`, not an already-wrapped copy.\n",
    "remote_train = fairing.config.fn(train)\n",
    "remote_train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
